{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fab6b7d5",
   "metadata": {},
   "source": [
    "# Demo\n",
    "Examples of how to use this repo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fea74aa3",
   "metadata": {},
   "source": [
    "These are the commands for VPT with `p=50`:\n",
    "\n",
    "VPT-deep\n",
    "```bash\n",
    "    --config-file configs/prompt/*.yaml\n",
    "    MODEL.TRANSFER_TYPE \"prompt\" \\\n",
    "    MODEL.PROMPT.DEEP \"True\" \\\n",
    "    MODEL.PROMPT.NUM_TOKENS \"50\" \\\n",
    "    MODEL.PROMPT.DROPOUT \"0.0\" \n",
    "```\n",
    "   \n",
    "VPT-shallow (we don't use dropout for VPT-shallow)\n",
    "```bash\n",
    "    --config-file configs/prompt/*.yaml\n",
    "    MODEL.TRANSFER_TYPE \"prompt\" \\\n",
    "    MODEL.PROMPT.DEEP \"False\" \\\n",
    "    MODEL.PROMPT.NUM_TOKENS \"50\" \\\n",
    "    MODEL.PROMPT.DROPOUT \"0.0\" \n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0825e402",
   "metadata": {},
   "source": [
    "Other transfer protocols presented in the paper:\n",
    "\n",
    "Full\n",
    "```bash\n",
    "    --config-file configs/finetune/*.yaml\n",
    "```\n",
    "\n",
    "Head-oriented methods:\n",
    "\n",
    "- Linear:\n",
    "```bash\n",
    "    --config-file configs/linear/*.yaml\n",
    "```\n",
    "\n",
    "- MLP-3 (3 layer MLP):\n",
    "```bash\n",
    "    --config-file configs/linear/*.yaml \\\n",
    "    MODEL.MLP_NUM \"2\"\n",
    "```\n",
    "\n",
    "- Partial-1:\n",
    "```bash\n",
    "    --config-file configs/finetune/*.yaml \\\n",
    "    MODEL.TRANSFER_TYPE \"partial-1\"\n",
    "```\n",
    "\n",
    "\n",
    "Backbone-oriented methods:\n",
    "\n",
    "- Sidetune:\n",
    "```bash\n",
    "    --config-file configs/linear/*.yaml\n",
    "    MODEL.TRANSFER_TYPE  \"side\" \n",
    "```\n",
    "\n",
    "- Bias: \n",
    "```bash\n",
    "    --config-file configs/finetune/*.yaml \\\n",
    "    MODEL.TRANSFER_TYPE \"tinytl-bias\"\n",
    "```\n",
    "\n",
    "- Adapters with `r=128`:\n",
    "```bash\n",
    "    --config-file configs/finetune/*.yaml\n",
    "    MODEL.ADAPTER.REDUCATION_FACTOR \"128\"\n",
    "    MODEL.TRANSFER_TYPE \"adapter\" \n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a755f5a",
   "metadata": {},
   "source": [
    "##  train.py\n",
    "The main script is `train.py`. Note for VTAB data, this script handles the final runs with 1000 training data. See `tune_vtab.py` for the full tune + final runs settings. Here are some examples.\n",
    "\n",
    "Note: it's recommended to directly use terminal for these command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14c5be87",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# launch final training with five random seeds for VTAB-dmlab, sun397 and eurosat. The hyperparameters are the same from our paper.\n",
    "model_root=<MODEL_ROOT>\n",
    "data_path=<DATA_PATH>\n",
    "output_dir=<OUTPUT_DIR>\n",
    "        \n",
    "# vtab-structured: dmlab\n",
    "# base_lr = 1.0\n",
    "# lr = base_lr / 256 * cfg.DATA.BATCH_SIZE\n",
    "for seed in \"42\" \"44\" \"82\" \"100\" \"800\"; do\n",
    "    python train.py \\\n",
    "        --config-file configs/prompt/cub.yaml \\\n",
    "        MODEL.TYPE \"vit\" \\\n",
    "        DATA.BATCH_SIZE \"64\" \\\n",
    "        MODEL.PROMPT.NUM_TOKENS \"100\" \\\n",
    "        MODEL.PROMPT.DEEP \"True\" \\\n",
    "        MODEL.PROMPT.DROPOUT \"0.1\" \\\n",
    "        DATA.FEATURE \"sup_vitb16_imagenet21k\" \\\n",
    "        DATA.NAME \"vtab-dmlab\" \\\n",
    "        DATA.NUMBER_CLASSES \"6\" \\\n",
    "        SOLVER.BASE_LR \"0.25\" \\\n",
    "        SOLVER.WEIGHT_DECAY \"0.001\" \\\n",
    "        SEED ${seed} \\\n",
    "        MODEL.MODEL_ROOT \"${model_root}\" \\\n",
    "        DATA.DATAPATH \"${data_path}\" \\\n",
    "        OUTPUT_DIR \"${output_dir}/seed${seed}\" \n",
    "done\n",
    "\n",
    "# vtab-natural: sun397\n",
    "# base_lr = 25\n",
    "# lr = base_lr / 256 * cfg.DATA.BATCH_SIZE\n",
    "for seed in \"42\" \"44\" \"82\" \"100\" \"800\"; do\n",
    "    python train.py \\\n",
    "        --config-file configs/prompt/cub.yaml \\\n",
    "        MODEL.TYPE \"vit\" \\\n",
    "        DATA.BATCH_SIZE \"128\" \\\n",
    "        MODEL.PROMPT.NUM_TOKENS \"5\" \\\n",
    "        MODEL.PROMPT.DEEP \"True\" \\\n",
    "        MODEL.PROMPT.DROPOUT \"0.1\" \\\n",
    "        DATA.FEATURE \"sup_vitb16_imagenet21k\" \\\n",
    "        DATA.NAME \"vtab-sun397\" \\\n",
    "        DATA.NUMBER_CLASSES \"397\" \\\n",
    "        SOLVER.BASE_LR \"12.5\" \\\n",
    "        SOLVER.WEIGHT_DECAY \"0.0001\" \\\n",
    "        SOLVER.TOTAL_EPOCH \"100\" \\\n",
    "        SEED ${seed} \\\n",
    "        MODEL.MODEL_ROOT \"${model_root}\" \\\n",
    "        DATA.DATAPATH \"${data_path}\" \\\n",
    "        OUTPUT_DIR \"${output_dir}/seed${seed}\" \n",
    "done\n",
    "\n",
    "# vtab-specialized: vtab-eurosat\n",
    "# base_lr = 1\n",
    "# lr = base_lr / 256 * cfg.DATA.BATCH_SIZE\n",
    "for seed in \"42\" \"44\" \"82\" \"100\" \"800\"; do\n",
    "    python train.py \\\n",
    "        --config-file configs/prompt/cub.yaml \\\n",
    "        MODEL.TYPE \"vit\" \\\n",
    "        DATA.BATCH_SIZE \"64\" \\\n",
    "        MODEL.PROMPT.NUM_TOKENS \"100\" \\\n",
    "        MODEL.PROMPT.DEEP \"True\" \\\n",
    "        MODEL.PROMPT.DROPOUT \"0.1\" \\\n",
    "        DATA.FEATURE \"sup_vitb16_imagenet21k\" \\\n",
    "        DATA.NAME \"vtab-eurosat\" \\\n",
    "        DATA.NUMBER_CLASSES \"10\" \\\n",
    "        SOLVER.BASE_LR \"0.25\" \\\n",
    "        SOLVER.WEIGHT_DECAY \"0.001\" \\\n",
    "        SOLVER.TOTAL_EPOCH \"100\" \\\n",
    "        SEED ${seed} \\\n",
    "        MODEL.MODEL_ROOT \"${model_root}\" \\\n",
    "        DATA.DATAPATH \"${data_path}\" \\\n",
    "        OUTPUT_DIR \"${output_dir}/seed${seed}\" \n",
    "done"
   ]
  },
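  {
   "cell_type": "markdown",
   "id": "c4d1e8a2",
   "metadata": {},
   "source": [
    "The `SOLVER.BASE_LR` values passed above follow the scaling rule quoted in the comments, `lr = base_lr / 256 * cfg.DATA.BATCH_SIZE`. A minimal sketch of that arithmetic, where the helper name `scaled_lr` is hypothetical and not part of the repo:\n",
    "\n",
    "```python\n",
    "def scaled_lr(base_lr, batch_size, reference_batch_size=256):\n",
    "    # Linear scaling relative to a reference batch size of 256, as in the comments above.\n",
    "    return base_lr / reference_batch_size * batch_size\n",
    "\n",
    "\n",
    "# (dataset, base_lr, batch size) triples taken from the comments and flags above.\n",
    "settings = [\n",
    "    (\"vtab-dmlab\", 1.0, 64),\n",
    "    (\"vtab-sun397\", 25.0, 128),\n",
    "    (\"vtab-eurosat\", 1.0, 64),\n",
    "]\n",
    "for name, base_lr, batch_size in settings:\n",
    "    print(name, scaled_lr(base_lr, batch_size))\n",
    "```\n",
    "\n",
    "The printed values (0.25, 12.5 and 0.25) are the ones given to `SOLVER.BASE_LR` in the three loops."
   ]
  },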
  {
   "cell_type": "markdown",
   "id": "63977e8e",
   "metadata": {},
   "source": [
    "## Get results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "14a527c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import pandas as pd\n",
    "\n",
    "from src.utils.vis_utils import get_df, average_df\n",
    "\n",
    "LOG_NAME = \"logs.txt\"\n",
    "pd.set_option('display.max_columns', None)\n",
    "pd.set_option('display.max_rows', None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a8a81bf6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "seed42: 100%|██████████| 3/3 [00:00<00:00, 241.66it/s]\n",
      "seed42: 100%|██████████| 3/3 [00:00<00:00, 240.13it/s]\n",
      "seed42: 100%|██████████| 3/3 [00:00<00:00, 229.10it/s]\n",
      "seed44: 100%|██████████| 3/3 [00:00<00:00, 223.63it/s]\n",
      "seed44: 100%|██████████| 3/3 [00:00<00:00, 265.73it/s]\n",
      "seed44: 100%|██████████| 3/3 [00:00<00:00, 232.97it/s]\n",
      "seed82: 100%|██████████| 3/3 [00:00<00:00, 221.82it/s]\n",
      "seed82: 100%|██████████| 3/3 [00:00<00:00, 224.52it/s]\n",
      "seed82: 100%|██████████| 3/3 [00:00<00:00, 258.04it/s]\n",
      "seed100: 100%|██████████| 3/3 [00:00<00:00, 222.09it/s]\n",
      "seed100: 100%|██████████| 3/3 [00:00<00:00, 215.91it/s]\n",
      "seed100: 100%|██████████| 3/3 [00:00<00:00, 236.46it/s]\n",
      "seed800: 100%|██████████| 3/3 [00:00<00:00, 236.60it/s]\n",
      "seed800: 100%|██████████| 3/3 [00:00<00:00, 229.92it/s]\n",
      "seed800: 100%|██████████| 3/3 [00:00<00:00, 252.79it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "      <th>feature</th>\n",
       "      <th>lr</th>\n",
       "      <th>wd</th>\n",
       "      <th>total_params</th>\n",
       "      <th>tuned_params</th>\n",
       "      <th>tuned / total (%)</th>\n",
       "      <th>batch_size</th>\n",
       "      <th>l-val_top1</th>\n",
       "      <th>l-test_top1</th>\n",
       "      <th>best_epoch</th>\n",
       "      <th>file</th>\n",
       "      <th>total_time</th>\n",
       "      <th>seed</th>\n",
       "      <th>type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>vtab-dmlab</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86724870</td>\n",
       "      <td>926214</td>\n",
       "      <td>1.0680</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>46.88</td>\n",
       "      <td>76 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 01:06:03</td>\n",
       "      <td>42</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>vtab-eurosat</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86727946</td>\n",
       "      <td>929290</td>\n",
       "      <td>1.0715</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>96.00</td>\n",
       "      <td>38 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 00:44:57</td>\n",
       "      <td>42</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>vtab-sun397</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>25.0</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>86150029</td>\n",
       "      <td>351373</td>\n",
       "      <td>0.4079</td>\n",
       "      <td>128</td>\n",
       "      <td>100.0</td>\n",
       "      <td>52.57</td>\n",
       "      <td>14 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 00:44:57</td>\n",
       "      <td>42</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>vtab-dmlab</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86724870</td>\n",
       "      <td>926214</td>\n",
       "      <td>1.0680</td>\n",
       "      <td>64</td>\n",
       "      <td>99.5</td>\n",
       "      <td>46.25</td>\n",
       "      <td>85 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 01:06:09</td>\n",
       "      <td>44</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>vtab-eurosat</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86727946</td>\n",
       "      <td>929290</td>\n",
       "      <td>1.0715</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>96.54</td>\n",
       "      <td>41 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 00:44:38</td>\n",
       "      <td>44</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>vtab-sun397</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>25.0</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>86150029</td>\n",
       "      <td>351373</td>\n",
       "      <td>0.4079</td>\n",
       "      <td>128</td>\n",
       "      <td>100.0</td>\n",
       "      <td>49.03</td>\n",
       "      <td>32 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 00:44:53</td>\n",
       "      <td>44</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>vtab-dmlab</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86724870</td>\n",
       "      <td>926214</td>\n",
       "      <td>1.0680</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>46.14</td>\n",
       "      <td>65 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 01:04:18</td>\n",
       "      <td>82</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>vtab-eurosat</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86727946</td>\n",
       "      <td>929290</td>\n",
       "      <td>1.0715</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>96.67</td>\n",
       "      <td>41 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 00:44:42</td>\n",
       "      <td>82</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>vtab-sun397</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>25.0</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>86150029</td>\n",
       "      <td>351373</td>\n",
       "      <td>0.4079</td>\n",
       "      <td>128</td>\n",
       "      <td>100.0</td>\n",
       "      <td>52.45</td>\n",
       "      <td>8 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 00:44:51</td>\n",
       "      <td>82</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>vtab-dmlab</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86724870</td>\n",
       "      <td>926214</td>\n",
       "      <td>1.0680</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>47.41</td>\n",
       "      <td>76 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 01:06:15</td>\n",
       "      <td>100</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>vtab-eurosat</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86727946</td>\n",
       "      <td>929290</td>\n",
       "      <td>1.0715</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>95.41</td>\n",
       "      <td>47 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 00:44:59</td>\n",
       "      <td>100</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>vtab-sun397</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>25.0</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>86150029</td>\n",
       "      <td>351373</td>\n",
       "      <td>0.4079</td>\n",
       "      <td>128</td>\n",
       "      <td>100.0</td>\n",
       "      <td>51.08</td>\n",
       "      <td>17 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 00:44:40</td>\n",
       "      <td>100</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>vtab-dmlab</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86724870</td>\n",
       "      <td>926214</td>\n",
       "      <td>1.0680</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>46.44</td>\n",
       "      <td>77 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 01:06:12</td>\n",
       "      <td>800</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>vtab-eurosat</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86727946</td>\n",
       "      <td>929290</td>\n",
       "      <td>1.0715</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>96.11</td>\n",
       "      <td>42 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 00:44:16</td>\n",
       "      <td>800</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>vtab-sun397</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>25.0</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>86150029</td>\n",
       "      <td>351373</td>\n",
       "      <td>0.4079</td>\n",
       "      <td>128</td>\n",
       "      <td>100.0</td>\n",
       "      <td>52.72</td>\n",
       "      <td>11 | 100</td>\n",
       "      <td>/fsx/menglin/experiments/2022prompt/output/rel...</td>\n",
       "      <td>0 days 01:01:28</td>\n",
       "      <td>800</td>\n",
       "      <td>VPT</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           data                 feature    lr      wd  total_params  \\\n",
       "2    vtab-dmlab  sup_vitb16_imagenet21k   1.0  0.0010      86724870   \n",
       "0  vtab-eurosat  sup_vitb16_imagenet21k   1.0  0.0010      86727946   \n",
       "1   vtab-sun397  sup_vitb16_imagenet21k  25.0  0.0001      86150029   \n",
       "2    vtab-dmlab  sup_vitb16_imagenet21k   1.0  0.0010      86724870   \n",
       "0  vtab-eurosat  sup_vitb16_imagenet21k   1.0  0.0010      86727946   \n",
       "1   vtab-sun397  sup_vitb16_imagenet21k  25.0  0.0001      86150029   \n",
       "2    vtab-dmlab  sup_vitb16_imagenet21k   1.0  0.0010      86724870   \n",
       "1  vtab-eurosat  sup_vitb16_imagenet21k   1.0  0.0010      86727946   \n",
       "0   vtab-sun397  sup_vitb16_imagenet21k  25.0  0.0001      86150029   \n",
       "2    vtab-dmlab  sup_vitb16_imagenet21k   1.0  0.0010      86724870   \n",
       "0  vtab-eurosat  sup_vitb16_imagenet21k   1.0  0.0010      86727946   \n",
       "1   vtab-sun397  sup_vitb16_imagenet21k  25.0  0.0001      86150029   \n",
       "2    vtab-dmlab  sup_vitb16_imagenet21k   1.0  0.0010      86724870   \n",
       "1  vtab-eurosat  sup_vitb16_imagenet21k   1.0  0.0010      86727946   \n",
       "0   vtab-sun397  sup_vitb16_imagenet21k  25.0  0.0001      86150029   \n",
       "\n",
       "   tuned_params  tuned / total (%)  batch_size  l-val_top1  l-test_top1  \\\n",
       "2        926214             1.0680          64       100.0        46.88   \n",
       "0        929290             1.0715          64       100.0        96.00   \n",
       "1        351373             0.4079         128       100.0        52.57   \n",
       "2        926214             1.0680          64        99.5        46.25   \n",
       "0        929290             1.0715          64       100.0        96.54   \n",
       "1        351373             0.4079         128       100.0        49.03   \n",
       "2        926214             1.0680          64       100.0        46.14   \n",
       "1        929290             1.0715          64       100.0        96.67   \n",
       "0        351373             0.4079         128       100.0        52.45   \n",
       "2        926214             1.0680          64       100.0        47.41   \n",
       "0        929290             1.0715          64       100.0        95.41   \n",
       "1        351373             0.4079         128       100.0        51.08   \n",
       "2        926214             1.0680          64       100.0        46.44   \n",
       "1        929290             1.0715          64       100.0        96.11   \n",
       "0        351373             0.4079         128       100.0        52.72   \n",
       "\n",
       "  best_epoch                                               file  \\\n",
       "2   76 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "0   38 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "1   14 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "2   85 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "0   41 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "1   32 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "2   65 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "1   41 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "0    8 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "2   76 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "0   47 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "1   17 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "2   77 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "1   42 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "0   11 | 100  /fsx/menglin/experiments/2022prompt/output/rel...   \n",
       "\n",
       "       total_time seed type  \n",
       "2 0 days 01:06:03   42  VPT  \n",
       "0 0 days 00:44:57   42  VPT  \n",
       "1 0 days 00:44:57   42  VPT  \n",
       "2 0 days 01:06:09   44  VPT  \n",
       "0 0 days 00:44:38   44  VPT  \n",
       "1 0 days 00:44:53   44  VPT  \n",
       "2 0 days 01:04:18   82  VPT  \n",
       "1 0 days 00:44:42   82  VPT  \n",
       "0 0 days 00:44:51   82  VPT  \n",
       "2 0 days 01:06:15  100  VPT  \n",
       "0 0 days 00:44:59  100  VPT  \n",
       "1 0 days 00:44:40  100  VPT  \n",
       "2 0 days 01:06:12  800  VPT  \n",
       "1 0 days 00:44:16  800  VPT  \n",
       "0 0 days 01:01:28  800  VPT  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "root = <MODEL_ROOT>\n",
    "df_list=[]\n",
    "for seed in [\"42\", \"44\", \"82\", \"100\", \"800\"]:\n",
    "#     model_type = f\"adapter_{r}\"\n",
    "    files = glob.glob(f\"{root}/seed{seed}/*/sup_vitb16_imagenet21k/*/*/{LOG_NAME}\")\n",
    "    for f in files:\n",
    "        df = get_df(files, f\"seed{seed}\", root, is_best=False, is_last=True)\n",
    "        if df is None:\n",
    "            continue\n",
    "        df[\"seed\"] = seed\n",
    "    df_list.append(df)\n",
    "\n",
    "df= pd.concat(df_list)\n",
    "df[\"type\"] = \"VPT\"\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1468233f",
   "metadata": {},
   "source": [
    "Take average of 5 runs for each dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "139f26c3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "      <th>feature</th>\n",
       "      <th>type</th>\n",
       "      <th>total_runs</th>\n",
       "      <th>l-test_top1</th>\n",
       "      <th>l-test_top1-std</th>\n",
       "      <th>lr</th>\n",
       "      <th>wd</th>\n",
       "      <th>total_params</th>\n",
       "      <th>tuned_params</th>\n",
       "      <th>tuned / total (%)</th>\n",
       "      <th>batch_size</th>\n",
       "      <th>l-val_top1</th>\n",
       "      <th>total_time</th>\n",
       "      <th>seed</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>vtab-dmlab</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>VPT</td>\n",
       "      <td>5</td>\n",
       "      <td>46.62</td>\n",
       "      <td>0.47</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86724870</td>\n",
       "      <td>926214</td>\n",
       "      <td>1.0680</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>0 days 01:06:03</td>\n",
       "      <td>42</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>vtab-eurosat</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>VPT</td>\n",
       "      <td>5</td>\n",
       "      <td>96.15</td>\n",
       "      <td>0.45</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>86727946</td>\n",
       "      <td>929290</td>\n",
       "      <td>1.0715</td>\n",
       "      <td>64</td>\n",
       "      <td>100.0</td>\n",
       "      <td>0 days 00:44:57</td>\n",
       "      <td>42</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>vtab-sun397</td>\n",
       "      <td>sup_vitb16_imagenet21k</td>\n",
       "      <td>VPT</td>\n",
       "      <td>5</td>\n",
       "      <td>51.57</td>\n",
       "      <td>1.40</td>\n",
       "      <td>25.0</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>86150029</td>\n",
       "      <td>351373</td>\n",
       "      <td>0.4079</td>\n",
       "      <td>128</td>\n",
       "      <td>100.0</td>\n",
       "      <td>0 days 00:44:57</td>\n",
       "      <td>42</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           data                 feature type  total_runs l-test_top1  \\\n",
       "0    vtab-dmlab  sup_vitb16_imagenet21k  VPT           5       46.62   \n",
       "2  vtab-eurosat  sup_vitb16_imagenet21k  VPT           5       96.15   \n",
       "1   vtab-sun397  sup_vitb16_imagenet21k  VPT           5       51.57   \n",
       "\n",
       "  l-test_top1-std    lr      wd  total_params  tuned_params  \\\n",
       "0            0.47   1.0  0.0010      86724870        926214   \n",
       "2            0.45   1.0  0.0010      86727946        929290   \n",
       "1            1.40  25.0  0.0001      86150029        351373   \n",
       "\n",
       "   tuned / total (%)  batch_size  l-val_top1      total_time seed  \n",
       "0             1.0680          64       100.0 0 days 01:06:03   42  \n",
       "2             1.0715          64       100.0 0 days 00:44:57   42  \n",
       "1             0.4079         128       100.0 0 days 00:44:57   42  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LR represents the base learning rate, not the scaled one.\n",
    "f_df = average_df(df, metric_names=[\"l-test_top1\"], take_average=True)\n",
    "f_df"
   ]
  },
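  {
   "cell_type": "markdown",
   "id": "5e7a90b1",
   "metadata": {},
   "source": [
    "The internals of `average_df` are not shown in this notebook. As a rough cross-check, the mean and standard deviation columns above can be recomputed from the per-seed `l-test_top1` values with the standard library alone. This sketch assumes a population standard deviation (`pstdev`), which is the variant that appears to match the table; the variable names are illustrative:\n",
    "\n",
    "```python\n",
    "from statistics import mean, pstdev\n",
    "\n",
    "# l-test_top1 per seed (42, 44, 82, 100, 800), copied from the per-seed table above.\n",
    "test_top1 = {\n",
    "    \"vtab-dmlab\": [46.88, 46.25, 46.14, 47.41, 46.44],\n",
    "    \"vtab-eurosat\": [96.00, 96.54, 96.67, 95.41, 96.11],\n",
    "    \"vtab-sun397\": [52.57, 49.03, 52.45, 51.08, 52.72],\n",
    "}\n",
    "\n",
    "# (mean, population std) per dataset, rounded to two decimals like the table.\n",
    "summary = {\n",
    "    name: (round(mean(values), 2), round(pstdev(values), 2))\n",
    "    for name, values in test_top1.items()\n",
    "}\n",
    "for name, (avg, std) in summary.items():\n",
    "    print(f\"{name}: {avg:.2f} +/- {std:.2f}\")\n",
    "```\n",
    "\n",
    "A sample standard deviation (`statistics.stdev`) would give slightly larger values, so it is worth checking which convention a reported number uses before comparing runs."
   ]
  },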
  {
   "cell_type": "markdown",
   "id": "e49c4375",
   "metadata": {},
   "source": [
    "## tune*.py\n",
    "Tune vtab or fgvc datasets. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c55ea4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# Tune VTAB-caltech101 with VPT:\n",
    "python tune_vtab.py \\\n",
    "    --train-type \"prompt\" \\\n",
    "    --config-file configs/prompt/cub.yaml \\\n",
    "    MODEL.TYPE \"vit\" \\\n",
    "    DATA.BATCH_SIZE \"128\" \\\n",
    "    MODEL.PROMPT.DEEP \"True\" \\\n",
    "    MODEL.PROMPT.DROPOUT \"0.1\" \\\n",
    "    MODEL.PROMPT.NUM_TOKENS \"10\" \\\n",
    "    DATA.FEATURE \"sup_vitb16_imagenet21k\" \\\n",
    "    DATA.NAME \"vtab-caltech101\" \\\n",
    "    DATA.NUMBER_CLASSES \"102\" \\\n",
    "    DATA.DATAPATH <DATA_PATH> \\\n",
    "    MODEL.MODEL_ROOT <MODEL_ROOT> \\\n",
    "    OUTPUT_DIR <OUTPUT_PATH> \n",
    "\n",
    "# Tune CUB with VPT:\n",
    "python tune_fgvc.py \\\n",
    "    --train-type \"prompt\" \\\n",
    "    --config-file configs/prompt/cub.yaml \\\n",
    "    MODEL.TYPE \"vit\" \\\n",
    "    DATA.BATCH_SIZE \"128\" \\\n",
    "    MODEL.PROMPT.DEEP \"True\" \\\n",
    "    MODEL.PROMPT.DROPOUT \"0.1\" \\\n",
    "    MODEL.PROMPT.NUM_TOKENS \"10\" \\\n",
    "    DATA.FEATURE \"sup_vitb16_imagenet21k\" \\\n",
    "    DATA.DATAPATH <DATA_PATH> \\\n",
    "    MODEL.MODEL_ROOT <MODEL_ROOT> \\\n",
    "    OUTPUT_DIR <OUTPUT_PATH> "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b7abeef",
   "metadata": {},
   "source": [
    "## Backbone choices"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87b5f6fe",
   "metadata": {},
   "source": [
    "- Swin-B\n",
    "\n",
    "```bash\n",
    "    MODEL.TYPE \"swin\" \\\n",
    "    DATA.FEATURE \"swinb_imagenet22k_224\"\n",
    "```\n",
    "\n",
    "- ResNet-50 (VPT with prompt location == pad)\n",
    "\n",
    "```bash\n",
    "    MODEL.TYPE \"resnet\" \\\n",
    "    DATA.FEATURE \"imagenet_sup_rn50\" \\\n",
    "    SOLVER.OPTIMIZER \"sgd\" \\\n",
    "    MODEL.PROMPT.LOCATION \"pad\" \\\n",
    "    MODEL.PROMPT.NUM_TOKENS \"5\" \n",
    "```\n",
    "\n",
    "- ConvNeXt-Base (VPT with prompt location == pad)\n",
    "\n",
    "```bash\n",
    "    MODEL.TYPE \"resnext\" \\\n",
    "    DATA.FEATURE \"imagenet22k_sup_rnx_base\" \\\n",
    "    MODEL.PROMPT.LOCATION \"pad\" \\\n",
    "    MODEL.PROMPT.NUM_TOKENS \"5\" \n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e15670c",
   "metadata": {},
   "source": [
    "ViT with self-supervised pre-training objectives:\n",
    "    \n",
    "- MAE\n",
    "\n",
    "```bash\n",
    "MODEL.TYPE \"ssl-vit\" \\\n",
    "DATA.FEATURE \"mae_vitb16\"\n",
    "```\n",
    "\n",
    "- MoCo-v3\n",
    "\n",
    "```bash\n",
    "MODEL.TYPE \"ssl-vit\" \\\n",
    "DATA.FEATURE \"mocov3_vitb\" \n",
    "```"
   ]
  },
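  {
   "cell_type": "markdown",
   "id": "9f3b6c2d",
   "metadata": {},
   "source": [
    "The backbone options listed above are all plain `KEY VALUE` overrides, so they can be kept in one lookup table when scripting several runs. A small sketch; the dictionary layout, its short keys, and the `to_cli_args` helper are illustrative and not part of the repo:\n",
    "\n",
    "```python\n",
    "# Override values copied from the backbone lists above.\n",
    "BACKBONE_OVERRIDES = {\n",
    "    \"swin-b\": {\"MODEL.TYPE\": \"swin\", \"DATA.FEATURE\": \"swinb_imagenet22k_224\"},\n",
    "    \"resnet-50\": {\n",
    "        \"MODEL.TYPE\": \"resnet\",\n",
    "        \"DATA.FEATURE\": \"imagenet_sup_rn50\",\n",
    "        \"SOLVER.OPTIMIZER\": \"sgd\",\n",
    "        \"MODEL.PROMPT.LOCATION\": \"pad\",\n",
    "        \"MODEL.PROMPT.NUM_TOKENS\": \"5\",\n",
    "    },\n",
    "    \"convnext-base\": {\n",
    "        \"MODEL.TYPE\": \"resnext\",\n",
    "        \"DATA.FEATURE\": \"imagenet22k_sup_rnx_base\",\n",
    "        \"MODEL.PROMPT.LOCATION\": \"pad\",\n",
    "        \"MODEL.PROMPT.NUM_TOKENS\": \"5\",\n",
    "    },\n",
    "    \"mae\": {\"MODEL.TYPE\": \"ssl-vit\", \"DATA.FEATURE\": \"mae_vitb16\"},\n",
    "    \"moco-v3\": {\"MODEL.TYPE\": \"ssl-vit\", \"DATA.FEATURE\": \"mocov3_vitb\"},\n",
    "}\n",
    "\n",
    "\n",
    "def to_cli_args(overrides):\n",
    "    # Flatten {key: value} into the alternating KEY VALUE tokens used on the command line.\n",
    "    args = []\n",
    "    for key, value in overrides.items():\n",
    "        args.extend([key, str(value)])\n",
    "    return args\n",
    "\n",
    "\n",
    "print(to_cli_args(BACKBONE_OVERRIDES[\"mae\"]))\n",
    "```\n",
    "\n",
    "The resulting list can be appended to the argument list of a `train.py` or `tune_*.py` invocation, after the `--config-file` option."
   ]
  },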
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b348f52a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "prompt",
   "language": "python",
   "name": "prompt"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
